Convolutional LSTM with attention mechanism

Tutorial on how to train a convolutional neural network combined with a bidirectional LSTM and an attention mechanism to predict protein subcellular localization.


In [1]:
# Import all the necessary modules
import os
import sys
os.environ["THEANO_FLAGS"] = "mode=FAST_RUN,optimizer=None,device=cpu,floatX=float32"
sys.path.insert(0,'..')
import numpy as np
import theano
import theano.tensor as T
import lasagne
from confusionmatrix import ConfusionMatrix
from utils import iterate_minibatches, LSTMAttentionDecodeFeedbackLayer
import matplotlib.pyplot as plt
import time
import itertools
%matplotlib inline
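ConfusionMatrix and LSTMAttentionDecodeFeedbackLayer come from the local helper modules shipped with this tutorial. If you do not have utils.py at hand, iterate_minibatches behaves roughly like the following generator (a minimal sketch, assuming only the call signature used in the training loop below):

def iterate_minibatches(inputs, targets, masks, batch_size, shuffle=False):
    # Yield (inputs, targets, masks) slices of batch_size examples each
    indices = np.arange(len(inputs))
    if shuffle:
        np.random.shuffle(indices)
    for start in range(0, len(inputs) - batch_size + 1, batch_size):
        excerpt = indices[start:start + batch_size]
        yield inputs[excerpt], targets[excerpt], masks[excerpt]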

Building the network

The first step is to define the network architecture. Here we use an input layer, two convolutional layers, a bidirectional LSTM, an attention layer, a dense layer and an output layer. These are the steps we are going to follow:

1.- Specify the hyperparameters of the network:


In [2]:
batch_size = 128
seq_len = 400
n_feat = 20
n_hid = 15
n_class = 10
lr = 0.0025
n_filt = 10
drop_prob = 0.5

2.- Define the input variables to our network:


In [3]:
# We use ftensor3 because the protein data is a 3D matrix in float32
input_var = T.ftensor3('inputs')
# ivector because the labels are a one-dimensional vector of integers
target_var = T.ivector('targets')
# fmatrix because the masks that flag the padded positions form a 2D matrix in float32
mask_var = T.fmatrix('masks')
# Dummy data to check the layer output shapes while building the network
X = np.random.randint(0,10,size=(batch_size,seq_len,n_feat)).astype('float32')
Xmask = np.ones((batch_size,seq_len)).astype('float32')
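The dummy mask above is all ones because every position of the random batch is treated as real. For actual padded sequences the mask is 1.0 at residue positions and 0.0 at padding. As a hypothetical illustration, a protein of true length 150 padded to seq_len positions would be masked like this:

# Hypothetical example: one sequence of true length 150, zero-padded to seq_len
true_len = 150
mask_example = np.zeros((seq_len,), dtype='float32')
mask_example[:true_len] = 1.0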

3.- Define the layers of the network:


In [4]:
# Input layer, holds the shape of the data
l_in = lasagne.layers.InputLayer(shape=(batch_size, None, n_feat), input_var=input_var, name='Input')
print('Input layer: {}'.format(
    lasagne.layers.get_output(l_in, inputs={l_in: input_var}).eval({input_var: X}).shape))

# Mask input layer
l_mask = lasagne.layers.InputLayer(shape=(batch_size, None), input_var=mask_var, name='Mask')
print('Mask layer: {}'.format(
    lasagne.layers.get_output(l_mask, inputs={l_mask: mask_var}).eval({mask_var: Xmask}).shape))

# Dimshuffle to (batch, features, positions) so the Conv1D layer reads the features as channels
l_shu = lasagne.layers.DimshuffleLayer(l_in, (0,2,1))
print('DimshuffleLayer layer: {}'.format(
    lasagne.layers.get_output(l_shu, inputs={l_in: input_var}).eval({input_var: X}).shape))

# Convolutional layers with different filter size
l_conv_a = lasagne.layers.Conv1DLayer(l_shu, num_filters=n_filt, pad='same', stride=1, 
                                      filter_size=3, nonlinearity=lasagne.nonlinearities.rectify)
print('Convolutional layer size 3: {}'.format(
    lasagne.layers.get_output(l_conv_a, inputs={l_in: input_var}).eval({input_var: X}).shape))

l_conv_b = lasagne.layers.Conv1DLayer(l_shu, num_filters=n_filt, pad='same', stride=1, 
                                      filter_size=5, nonlinearity=lasagne.nonlinearities.rectify)
print('Convolutional layer size 5: {}'.format(
    lasagne.layers.get_output(l_conv_b, inputs={l_in: input_var}).eval({input_var: X}).shape))

# Concatenate the outputs of both convolutions along the filter axis
l_conc = lasagne.layers.ConcatLayer([l_conv_a, l_conv_b], axis=1)
print('Concatenated convolutional layers: {}'.format(
    lasagne.layers.get_output(l_conc, inputs={l_in: input_var}).eval({input_var: X}).shape))

# Second CNN layer
l_conv_final = lasagne.layers.Conv1DLayer(l_conc, num_filters=n_filt*2, pad='same', 
                                          stride=1, filter_size=3, 
                                          nonlinearity=lasagne.nonlinearities.rectify)
print('Final convolutional layer: {}'.format(
    lasagne.layers.get_output(l_conv_final, inputs={l_in: input_var}).eval({input_var: X}).shape))

l_reshu = lasagne.layers.DimshuffleLayer(l_conv_final, (0,2,1))
print('Second DimshuffleLayer layer: {}'.format(
    lasagne.layers.get_output(l_reshu, inputs={l_in: input_var}).eval({input_var: X}).shape))

l_fwd = lasagne.layers.LSTMLayer(l_reshu, num_units=n_hid, name='LSTMFwd', mask_input=l_mask,
                                 nonlinearity=lasagne.nonlinearities.tanh)
l_bck = lasagne.layers.LSTMLayer(l_reshu, num_units=n_hid, name='LSTMBck', mask_input=l_mask,
                                 backwards=True, nonlinearity=lasagne.nonlinearities.tanh)
print('Forward LSTM layer: {}'.format(
    lasagne.layers.get_output(l_fwd, inputs={l_in: input_var, l_mask: mask_var}).eval(
        {input_var: X, mask_var:Xmask}).shape))
print('Backward LSTM layer: {}'.format(
    lasagne.layers.get_output(l_bck, inputs={l_in: input_var, l_mask: mask_var}).eval(
        {input_var: X, mask_var:Xmask}).shape))

# Concatenate both layers
l_conc_lstm = lasagne.layers.ConcatLayer([l_fwd, l_bck], axis=2)

print('Concatenated hidden states: {}'.format(
    lasagne.layers.get_output(l_conc_lstm, inputs={l_in: input_var, l_mask: mask_var}).eval(
        {input_var: X, mask_var:Xmask}).shape))

l_att = LSTMAttentionDecodeFeedbackLayer(l_conc_lstm, mask_input=l_mask, 
                                         num_units=n_hid*2, aln_num_units=n_hid, 
                                         n_decodesteps=2, name='LSTMAttention')
print('Attention layer: {}'.format(
    lasagne.layers.get_output(l_att, inputs={l_in: input_var, l_mask: mask_var}).eval(
        {input_var: X, mask_var:Xmask}).shape))

l_last_hid = lasagne.layers.SliceLayer(l_att, indices=-1, axis=1)
print('Last decoding step: {}'.format(
    lasagne.layers.get_output(l_last_hid, inputs={l_in: input_var, l_mask: mask_var}).eval(
        {input_var: X, mask_var:Xmask}).shape))


# Dense layer with ReLU activation function
l_dense = lasagne.layers.DenseLayer(l_last_hid, num_units=n_hid*2, name="Dense",
                                    nonlinearity=lasagne.nonlinearities.rectify)
print('Dense layer: {}'.format(
    lasagne.layers.get_output(l_dense, inputs={l_in: input_var, l_mask: mask_var}).eval(
        {input_var: X, mask_var:Xmask}).shape))

# Output layer with a Softmax activation function. Note that we include a dropout layer
l_out = lasagne.layers.DenseLayer(lasagne.layers.dropout(l_dense, p=drop_prob), num_units=n_class, name="Softmax", 
                                  nonlinearity=lasagne.nonlinearities.softmax)
print('Output layer: {}'.format(
    lasagne.layers.get_output(l_out, inputs={l_in: input_var, l_mask: mask_var}).eval(
        {input_var: X, mask_var:Xmask}).shape))


Input layer: (128, 400, 20)
Mask layer: (128, 400)
DimshuffleLayer layer: (128, 20, 400)
Convolutional layer size 3: (128, 10, 400)
Convolutional layer size 5: (128, 10, 400)
Concatenated convolutional layers: (128, 20, 400)
Final convolutional layer: (128, 20, 400)
Second DimshuffleLayer layer: (128, 400, 20)
Forward LSTM layer: (128, 400, 15)
Backward LSTM layer: (128, 400, 15)
Concatenated hidden states: (128, 400, 30)
Attention layer: (128, 2, 30)
Last decoding step: (128, 30)
Dense layer: (128, 30)
Output layer: (128, 10)
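As a final sanity check before training, Lasagne can report the total number of trainable parameters of the network we just built:

# Count the trainable parameters of the full network
print(lasagne.layers.count_params(l_out, trainable=True))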

4.- Calculate the prediction and network loss for the training set and update the network weights:


In [5]:
# Get the network output for training; deterministic=False keeps dropout active
prediction = lasagne.layers.get_output(l_out, inputs={l_in: input_var, l_mask: mask_var}, deterministic=False)

# Calculate the categorical cross entropy between the labels and the prediction
t_loss = T.nnet.categorical_crossentropy(prediction, target_var)

# Training loss
loss = T.mean(t_loss)

# Parameters
params = lasagne.layers.get_all_params([l_out], trainable=True)

# Get the network gradients and perform total norm constraint normalization
all_grads = lasagne.updates.total_norm_constraint(T.grad(loss, params),3)

# Update parameters using ADAM 
updates = lasagne.updates.adam(all_grads, params, learning_rate=lr)
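total_norm_constraint rescales the whole list of gradients whenever their joint L2 norm exceeds the threshold (3 here), which helps keep the recurrent layers from taking destabilizing steps. In plain numpy the rescaling is roughly the following (a sketch of the idea, not the actual Theano graph):

# Rough numpy equivalent of total_norm_constraint
def clip_total_norm(grads, max_norm):
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-7))
    return [g * scale for g in grads]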

5.- Calculate the prediction and network loss for the validation set:


In [6]:
# Get the network output for validation; deterministic=True disables dropout
val_prediction = lasagne.layers.get_output(l_out, inputs={l_in: input_var, l_mask: mask_var}, deterministic=True)

# Calculate the categorical cross entropy between the labels and the prediction
t_val_loss = lasagne.objectives.categorical_crossentropy(val_prediction, target_var)

# Validation loss 
val_loss = T.mean(t_val_loss)

6.- Build theano functions:


In [ ]:
# Build functions
train_fn = theano.function([input_var, target_var, mask_var], [loss, prediction], updates=updates)
val_fn = theano.function([input_var, target_var, mask_var], [val_loss, val_prediction, l_att.alpha])
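Compiling these functions can take a while. Once they are built, you can smoke-test them on the dummy batch from step 2; the random labels below are hypothetical and only check shapes, and we use val_fn so that no parameters are updated:

# Sanity check with the dummy data from step 2
y_dummy = np.random.randint(0, n_class, size=batch_size).astype('int32')
dummy_loss, dummy_pred, dummy_alpha = val_fn(X, y_dummy, Xmask)
print(dummy_pred.shape)  # (128, 10)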

Load dataset

Once the network is built, the next step is to load the training and validation sets.


In [8]:
# Load the encoded protein sequences, labels and masks
train = np.load('data/reduced_train.npz')
X_train = train['X_train']
y_train = train['y_train']
mask_train = train['mask_train']
print(X_train.shape)


(2423, 400, 20)

In [9]:
validation = np.load('data/reduced_val.npz')
X_val = validation['X_val']
y_val = validation['y_val']
mask_val = validation['mask_val']
print(X_val.shape)


(635, 400, 20)
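With ten compartments it is worth checking how balanced the labels are before training. A quick sketch, assuming the labels are integers in 0..n_class-1 as the ivector above implies:

# Class counts in the training and validation labels
print(np.bincount(y_train.astype(np.int32), minlength=n_class))
print(np.bincount(y_val.astype(np.int32), minlength=n_class))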

Training

Once the data is ready and the network is compiled, we can start training the model. First we define the number of epochs to run.


In [10]:
# Number of epochs
num_epochs = 120

# Lists to save loss and accuracy of each epoch
loss_training = []
loss_validation = []
acc_training = []
acc_validation = []
start_time = time.time()
min_val_loss = float("inf")

# Start training 
for epoch in range(num_epochs):
    
    # Full pass training set
    train_err = 0
    train_batches = 0
    confusion_train = ConfusionMatrix(n_class)

    # Generate minibatches and train on each one of them
    for batch in iterate_minibatches(X_train.astype(np.float32), y_train.astype(np.int32), 
                                     mask_train.astype(np.float32), batch_size, shuffle=True):
        # Inputs to the network
        inputs, targets, in_masks = batch
        # Calculate loss and prediction
        tr_err, predict = train_fn(inputs, targets, in_masks)
        train_err += tr_err
        train_batches += 1
        # Get the predicted class, the one with the highest predicted probability
        preds = np.argmax(predict, axis=-1)
        confusion_train.batch_add(targets, preds)
    
    # Average loss and accuracy
    train_loss = train_err / train_batches
    train_accuracy = confusion_train.accuracy()
    cf_train = confusion_train.ret_mat()

    val_err = 0
    val_batches = 0
    confusion_valid = ConfusionMatrix(n_class)
    
    # Generate minibatches and validate on each one of them, same procedure as before
    for batch in iterate_minibatches(X_val.astype(np.float32), y_val.astype(np.int32), 
                                     mask_val.astype(np.float32), batch_size, shuffle=True):
        inputs, targets, in_masks = batch
        err, predict_val, alphas = val_fn(inputs, targets, in_masks)
        val_err += err
        val_batches += 1
        preds = np.argmax(predict_val, axis=-1)
        confusion_valid.batch_add(targets, preds)

    val_loss = val_err / val_batches
    val_accuracy = confusion_valid.accuracy()
    cf_val = confusion_valid.ret_mat()
    
    loss_training.append(train_loss)
    loss_validation.append(val_loss)
    acc_training.append(train_accuracy)
    acc_validation.append(val_accuracy)
    
    # Save the model parameters at the epoch with the lowest validation loss
    if min_val_loss > val_loss:
        min_val_loss = val_loss
        np.savez('params/CNN-LSTM-Attention_params.npz', *lasagne.layers.get_all_param_values(l_out))
    
    print("Epoch {} of {} time elapsed {:.3f}s".format(epoch + 1, num_epochs, time.time() - start_time))
    print("  training loss:\t\t{:.6f}".format(train_loss))
    print("  validation loss:\t\t{:.6f}".format(val_loss))
    print("  training accuracy:\t\t{:.2f} %".format(train_accuracy * 100))
    print("  validation accuracy:\t\t{:.2f} %".format(val_accuracy * 100))


Epoch 1 of 120 time elapsed 11.487s
  training loss:		2.263919
  validation loss:		2.176759
  training accuracy:		16.28 %
  validation accuracy:		14.84 %
Epoch 2 of 120 time elapsed 21.860s
  training loss:		2.154702
  validation loss:		2.059178
  training accuracy:		21.71 %
  validation accuracy:		23.75 %
Epoch 3 of 120 time elapsed 32.363s
  training loss:		2.096224
  validation loss:		2.031837
  training accuracy:		22.74 %
  validation accuracy:		23.44 %
Epoch 4 of 120 time elapsed 42.977s
  training loss:		2.071753
  validation loss:		2.012151
  training accuracy:		22.78 %
  validation accuracy:		34.53 %
Epoch 5 of 120 time elapsed 53.531s
  training loss:		2.031582
  validation loss:		1.946249
  training accuracy:		25.86 %
  validation accuracy:		37.50 %
Epoch 6 of 120 time elapsed 64.223s
  training loss:		1.946548
  validation loss:		1.850198
  training accuracy:		33.43 %
  validation accuracy:		37.97 %
Epoch 7 of 120 time elapsed 75.042s
  training loss:		1.815785
  validation loss:		1.658585
  training accuracy:		38.69 %
  validation accuracy:		40.62 %
Epoch 8 of 120 time elapsed 85.641s
  training loss:		1.702246
  validation loss:		1.557559
  training accuracy:		40.13 %
  validation accuracy:		42.03 %
Epoch 9 of 120 time elapsed 96.411s
  training loss:		1.611471
  validation loss:		1.494310
  training accuracy:		42.68 %
  validation accuracy:		42.19 %
Epoch 10 of 120 time elapsed 107.001s
  training loss:		1.589780
  validation loss:		1.438481
  training accuracy:		41.45 %
  validation accuracy:		43.12 %
Epoch 11 of 120 time elapsed 117.587s
  training loss:		1.528490
  validation loss:		1.414588
  training accuracy:		42.68 %
  validation accuracy:		44.22 %
Epoch 12 of 120 time elapsed 128.318s
  training loss:		1.493151
  validation loss:		1.379526
  training accuracy:		41.94 %
  validation accuracy:		45.62 %
Epoch 13 of 120 time elapsed 138.948s
  training loss:		1.458608
  validation loss:		1.367892
  training accuracy:		44.70 %
  validation accuracy:		45.94 %
Epoch 14 of 120 time elapsed 149.626s
  training loss:		1.419482
  validation loss:		1.319329
  training accuracy:		45.68 %
  validation accuracy:		46.72 %
Epoch 15 of 120 time elapsed 160.423s
  training loss:		1.389282
  validation loss:		1.271499
  training accuracy:		48.11 %
  validation accuracy:		49.69 %
Epoch 16 of 120 time elapsed 171.278s
  training loss:		1.334410
  validation loss:		1.209650
  training accuracy:		49.42 %
  validation accuracy:		51.09 %
Epoch 17 of 120 time elapsed 182.231s
  training loss:		1.275781
  validation loss:		1.161432
  training accuracy:		50.33 %
  validation accuracy:		51.41 %
Epoch 18 of 120 time elapsed 193.396s
  training loss:		1.300343
  validation loss:		1.157342
  training accuracy:		49.14 %
  validation accuracy:		51.56 %
Epoch 19 of 120 time elapsed 204.268s
  training loss:		1.228059
  validation loss:		1.106514
  training accuracy:		51.85 %
  validation accuracy:		55.78 %
Epoch 20 of 120 time elapsed 215.027s
  training loss:		1.157641
  validation loss:		1.037830
  training accuracy:		54.77 %
  validation accuracy:		56.72 %
Epoch 21 of 120 time elapsed 225.841s
  training loss:		1.088943
  validation loss:		0.998614
  training accuracy:		57.85 %
  validation accuracy:		59.22 %
Epoch 22 of 120 time elapsed 236.692s
  training loss:		1.036106
  validation loss:		0.979258
  training accuracy:		59.29 %
  validation accuracy:		61.88 %
Epoch 23 of 120 time elapsed 247.091s
  training loss:		1.023218
  validation loss:		1.006143
  training accuracy:		59.83 %
  validation accuracy:		58.91 %
Epoch 24 of 120 time elapsed 257.966s
  training loss:		0.967990
  validation loss:		0.873826
  training accuracy:		62.54 %
  validation accuracy:		65.16 %
Epoch 25 of 120 time elapsed 268.395s
  training loss:		0.936788
  validation loss:		0.885650
  training accuracy:		63.61 %
  validation accuracy:		60.94 %
Epoch 26 of 120 time elapsed 279.178s
  training loss:		0.910867
  validation loss:		0.834761
  training accuracy:		63.03 %
  validation accuracy:		62.03 %
Epoch 27 of 120 time elapsed 289.549s
  training loss:		0.884867
  validation loss:		0.837570
  training accuracy:		62.71 %
  validation accuracy:		63.91 %
Epoch 28 of 120 time elapsed 299.995s
  training loss:		0.879741
  validation loss:		0.836190
  training accuracy:		64.68 %
  validation accuracy:		67.66 %
Epoch 29 of 120 time elapsed 311.014s
  training loss:		0.832879
  validation loss:		0.784783
  training accuracy:		67.43 %
  validation accuracy:		67.19 %
Epoch 30 of 120 time elapsed 321.447s
  training loss:		0.821918
  validation loss:		0.790819
  training accuracy:		67.35 %
  validation accuracy:		71.88 %
Epoch 31 of 120 time elapsed 332.361s
  training loss:		0.811177
  validation loss:		0.763301
  training accuracy:		69.08 %
  validation accuracy:		72.34 %
Epoch 32 of 120 time elapsed 342.686s
  training loss:		0.802236
  validation loss:		0.765136
  training accuracy:		70.64 %
  validation accuracy:		72.50 %
Epoch 33 of 120 time elapsed 352.918s
  training loss:		0.774178
  validation loss:		0.786716
  training accuracy:		71.13 %
  validation accuracy:		69.53 %
Epoch 34 of 120 time elapsed 363.749s
  training loss:		0.768253
  validation loss:		0.758003
  training accuracy:		71.22 %
  validation accuracy:		74.69 %
Epoch 35 of 120 time elapsed 374.660s
  training loss:		0.760151
  validation loss:		0.754776
  training accuracy:		72.45 %
  validation accuracy:		72.97 %
Epoch 36 of 120 time elapsed 385.470s
  training loss:		0.715217
  validation loss:		0.754309
  training accuracy:		73.15 %
  validation accuracy:		72.97 %
Epoch 37 of 120 time elapsed 395.949s
  training loss:		0.746473
  validation loss:		0.777290
  training accuracy:		70.72 %
  validation accuracy:		70.94 %
Epoch 38 of 120 time elapsed 406.922s
  training loss:		0.766780
  validation loss:		0.734981
  training accuracy:		71.92 %
  validation accuracy:		72.97 %
Epoch 39 of 120 time elapsed 417.811s
  training loss:		0.734888
  validation loss:		0.708020
  training accuracy:		73.36 %
  validation accuracy:		74.84 %
Epoch 40 of 120 time elapsed 428.374s
  training loss:		0.717764
  validation loss:		0.732893
  training accuracy:		73.40 %
  validation accuracy:		73.44 %
Epoch 41 of 120 time elapsed 438.822s
  training loss:		0.713739
  validation loss:		0.722142
  training accuracy:		72.45 %
  validation accuracy:		72.50 %
Epoch 42 of 120 time elapsed 449.267s
  training loss:		0.685794
  validation loss:		0.727481
  training accuracy:		74.38 %
  validation accuracy:		72.97 %
Epoch 43 of 120 time elapsed 459.683s
  training loss:		0.693224
  validation loss:		0.728842
  training accuracy:		74.79 %
  validation accuracy:		75.94 %
Epoch 44 of 120 time elapsed 470.539s
  training loss:		0.670882
  validation loss:		0.690923
  training accuracy:		75.62 %
  validation accuracy:		74.38 %
Epoch 45 of 120 time elapsed 481.042s
  training loss:		0.659008
  validation loss:		0.704600
  training accuracy:		76.19 %
  validation accuracy:		77.03 %
Epoch 46 of 120 time elapsed 492.194s
  training loss:		0.685140
  validation loss:		0.672256
  training accuracy:		74.55 %
  validation accuracy:		78.44 %
Epoch 47 of 120 time elapsed 502.676s
  training loss:		0.647631
  validation loss:		0.720785
  training accuracy:		76.23 %
  validation accuracy:		75.62 %
Epoch 48 of 120 time elapsed 513.155s
  training loss:		0.628003
  validation loss:		0.721225
  training accuracy:		76.85 %
  validation accuracy:		76.88 %
Epoch 49 of 120 time elapsed 523.631s
  training loss:		0.662357
  validation loss:		0.703630
  training accuracy:		75.62 %
  validation accuracy:		75.47 %
Epoch 50 of 120 time elapsed 534.086s
  training loss:		0.644414
  validation loss:		0.683645
  training accuracy:		76.23 %
  validation accuracy:		77.34 %
Epoch 51 of 120 time elapsed 544.589s
  training loss:		0.630702
  validation loss:		0.715532
  training accuracy:		77.01 %
  validation accuracy:		77.19 %
Epoch 52 of 120 time elapsed 555.166s
  training loss:		0.637966
  validation loss:		0.699446
  training accuracy:		76.36 %
  validation accuracy:		78.28 %
Epoch 53 of 120 time elapsed 565.618s
  training loss:		0.654705
  validation loss:		0.694075
  training accuracy:		76.27 %
  validation accuracy:		80.00 %
Epoch 54 of 120 time elapsed 576.081s
  training loss:		0.636719
  validation loss:		0.702866
  training accuracy:		75.53 %
  validation accuracy:		78.91 %
Epoch 55 of 120 time elapsed 586.521s
  training loss:		0.605409
  validation loss:		0.698454
  training accuracy:		78.45 %
  validation accuracy:		77.03 %
Epoch 56 of 120 time elapsed 597.436s
  training loss:		0.592588
  validation loss:		0.670544
  training accuracy:		78.50 %
  validation accuracy:		81.56 %
Epoch 57 of 120 time elapsed 607.937s
  training loss:		0.600179
  validation loss:		0.687857
  training accuracy:		79.07 %
  validation accuracy:		80.00 %
Epoch 58 of 120 time elapsed 618.485s
  training loss:		0.586195
  validation loss:		0.683595
  training accuracy:		79.44 %
  validation accuracy:		76.09 %
Epoch 59 of 120 time elapsed 628.999s
  training loss:		0.577779
  validation loss:		0.684670
  training accuracy:		78.87 %
  validation accuracy:		79.84 %
Epoch 60 of 120 time elapsed 639.481s
  training loss:		0.585924
  validation loss:		0.703338
  training accuracy:		79.15 %
  validation accuracy:		79.53 %
Epoch 61 of 120 time elapsed 649.971s
  training loss:		0.595322
  validation loss:		0.694697
  training accuracy:		78.99 %
  validation accuracy:		80.62 %
Epoch 62 of 120 time elapsed 660.901s
  training loss:		0.570188
  validation loss:		0.668738
  training accuracy:		79.89 %
  validation accuracy:		81.25 %
Epoch 63 of 120 time elapsed 671.757s
  training loss:		0.605759
  validation loss:		0.702803
  training accuracy:		79.61 %
  validation accuracy:		78.75 %
Epoch 64 of 120 time elapsed 682.738s
  training loss:		0.594417
  validation loss:		0.675503
  training accuracy:		79.61 %
  validation accuracy:		80.78 %
Epoch 65 of 120 time elapsed 693.809s
  training loss:		0.565230
  validation loss:		0.666512
  training accuracy:		80.88 %
  validation accuracy:		79.69 %
Epoch 66 of 120 time elapsed 704.412s
  training loss:		0.530147
  validation loss:		0.677351
  training accuracy:		80.96 %
  validation accuracy:		81.41 %
Epoch 67 of 120 time elapsed 715.396s
  training loss:		0.562511
  validation loss:		0.651610
  training accuracy:		80.80 %
  validation accuracy:		82.97 %
Epoch 68 of 120 time elapsed 726.509s
  training loss:		0.543604
  validation loss:		0.636928
  training accuracy:		81.25 %
  validation accuracy:		82.66 %
Epoch 69 of 120 time elapsed 737.015s
  training loss:		0.556453
  validation loss:		0.702337
  training accuracy:		81.37 %
  validation accuracy:		80.31 %
Epoch 70 of 120 time elapsed 747.640s
  training loss:		0.548281
  validation loss:		0.664213
  training accuracy:		81.21 %
  validation accuracy:		81.56 %
Epoch 71 of 120 time elapsed 758.203s
  training loss:		0.512985
  validation loss:		0.642470
  training accuracy:		82.20 %
  validation accuracy:		82.50 %
Epoch 72 of 120 time elapsed 768.764s
  training loss:		0.525078
  validation loss:		0.690799
  training accuracy:		82.89 %
  validation accuracy:		82.34 %
Epoch 73 of 120 time elapsed 779.324s
  training loss:		0.516912
  validation loss:		0.639689
  training accuracy:		82.15 %
  validation accuracy:		82.03 %
Epoch 74 of 120 time elapsed 789.917s
  training loss:		0.503190
  validation loss:		0.667085
  training accuracy:		82.20 %
  validation accuracy:		82.66 %
Epoch 75 of 120 time elapsed 800.931s
  training loss:		0.490152
  validation loss:		0.623988
  training accuracy:		83.80 %
  validation accuracy:		82.81 %
Epoch 76 of 120 time elapsed 811.491s
  training loss:		0.493978
  validation loss:		0.746356
  training accuracy:		82.07 %
  validation accuracy:		78.28 %
Epoch 77 of 120 time elapsed 822.074s
  training loss:		0.555486
  validation loss:		0.676911
  training accuracy:		81.50 %
  validation accuracy:		81.72 %
Epoch 78 of 120 time elapsed 833.581s
  training loss:		0.526313
  validation loss:		0.662417
  training accuracy:		82.61 %
  validation accuracy:		81.41 %
Epoch 79 of 120 time elapsed 844.198s
  training loss:		0.499115
  validation loss:		0.701444
  training accuracy:		82.89 %
  validation accuracy:		82.19 %
Epoch 80 of 120 time elapsed 854.761s
  training loss:		0.537247
  validation loss:		0.670799
  training accuracy:		83.35 %
  validation accuracy:		81.41 %
Epoch 81 of 120 time elapsed 865.327s
  training loss:		0.474299
  validation loss:		0.661066
  training accuracy:		84.13 %
  validation accuracy:		82.50 %
Epoch 82 of 120 time elapsed 876.804s
  training loss:		0.500870
  validation loss:		0.681877
  training accuracy:		83.55 %
  validation accuracy:		81.88 %
Epoch 83 of 120 time elapsed 888.774s
  training loss:		0.484578
  validation loss:		0.668530
  training accuracy:		83.76 %
  validation accuracy:		82.66 %
Epoch 84 of 120 time elapsed 900.662s
  training loss:		0.467297
  validation loss:		0.649309
  training accuracy:		83.92 %
  validation accuracy:		83.12 %
Epoch 85 of 120 time elapsed 913.194s
  training loss:		0.461167
  validation loss:		0.632821
  training accuracy:		84.09 %
  validation accuracy:		82.66 %
Epoch 86 of 120 time elapsed 926.646s
  training loss:		0.476707
  validation loss:		0.637376
  training accuracy:		83.51 %
  validation accuracy:		84.22 %
Epoch 87 of 120 time elapsed 938.945s
  training loss:		0.481624
  validation loss:		0.702212
  training accuracy:		84.50 %
  validation accuracy:		83.28 %
Epoch 88 of 120 time elapsed 949.855s
  training loss:		0.428675
  validation loss:		0.671726
  training accuracy:		85.03 %
  validation accuracy:		83.75 %
Epoch 89 of 120 time elapsed 960.696s
  training loss:		0.442794
  validation loss:		0.669974
  training accuracy:		85.07 %
  validation accuracy:		84.53 %
Epoch 90 of 120 time elapsed 971.291s
  training loss:		0.458515
  validation loss:		0.722932
  training accuracy:		83.63 %
  validation accuracy:		82.19 %
Epoch 91 of 120 time elapsed 981.873s
  training loss:		0.452915
  validation loss:		0.727832
  training accuracy:		84.54 %
  validation accuracy:		79.84 %
Epoch 92 of 120 time elapsed 992.487s
  training loss:		0.436435
  validation loss:		0.688239
  training accuracy:		85.12 %
  validation accuracy:		82.66 %
Epoch 93 of 120 time elapsed 1003.114s
  training loss:		0.428046
  validation loss:		0.670575
  training accuracy:		86.92 %
  validation accuracy:		84.06 %
Epoch 94 of 120 time elapsed 1013.691s
  training loss:		0.428057
  validation loss:		0.676983
  training accuracy:		85.36 %
  validation accuracy:		83.75 %
Epoch 95 of 120 time elapsed 1024.271s
  training loss:		0.428194
  validation loss:		0.699072
  training accuracy:		85.36 %
  validation accuracy:		82.66 %
Epoch 96 of 120 time elapsed 1034.940s
  training loss:		0.416567
  validation loss:		0.762886
  training accuracy:		85.94 %
  validation accuracy:		81.72 %
Epoch 97 of 120 time elapsed 1045.561s
  training loss:		0.409410
  validation loss:		0.688278
  training accuracy:		86.60 %
  validation accuracy:		82.97 %
Epoch 98 of 120 time elapsed 1056.217s
  training loss:		0.459629
  validation loss:		0.711774
  training accuracy:		85.49 %
  validation accuracy:		82.19 %
Epoch 99 of 120 time elapsed 1066.804s
  training loss:		0.423189
  validation loss:		0.740589
  training accuracy:		85.94 %
  validation accuracy:		83.75 %
Epoch 100 of 120 time elapsed 1077.415s
  training loss:		0.424325
  validation loss:		0.668051
  training accuracy:		85.86 %
  validation accuracy:		83.75 %
Epoch 101 of 120 time elapsed 1088.085s
  training loss:		0.432789
  validation loss:		0.655419
  training accuracy:		85.61 %
  validation accuracy:		84.22 %
Epoch 102 of 120 time elapsed 1098.748s
  training loss:		0.387184
  validation loss:		0.703175
  training accuracy:		86.97 %
  validation accuracy:		84.38 %
Epoch 103 of 120 time elapsed 1109.357s
  training loss:		0.392381
  validation loss:		0.653755
  training accuracy:		87.21 %
  validation accuracy:		83.28 %
Epoch 104 of 120 time elapsed 1119.903s
  training loss:		0.417836
  validation loss:		0.667676
  training accuracy:		86.76 %
  validation accuracy:		84.53 %
Epoch 105 of 120 time elapsed 1130.462s
  training loss:		0.369029
  validation loss:		0.701228
  training accuracy:		87.66 %
  validation accuracy:		82.50 %
Epoch 106 of 120 time elapsed 1141.040s
  training loss:		0.375349
  validation loss:		0.731407
  training accuracy:		87.42 %
  validation accuracy:		83.12 %
Epoch 107 of 120 time elapsed 1153.564s
  training loss:		0.389776
  validation loss:		0.714112
  training accuracy:		87.05 %
  validation accuracy:		83.12 %
Epoch 108 of 120 time elapsed 1165.996s
  training loss:		0.382523
  validation loss:		0.673184
  training accuracy:		86.47 %
  validation accuracy:		83.75 %
Epoch 109 of 120 time elapsed 1176.638s
  training loss:		0.387500
  validation loss:		0.725913
  training accuracy:		87.46 %
  validation accuracy:		82.66 %
Epoch 110 of 120 time elapsed 1187.295s
  training loss:		0.408999
  validation loss:		0.702441
  training accuracy:		86.10 %
  validation accuracy:		83.59 %
Epoch 111 of 120 time elapsed 1198.020s
  training loss:		0.416931
  validation loss:		0.728416
  training accuracy:		85.81 %
  validation accuracy:		84.22 %
Epoch 112 of 120 time elapsed 1208.676s
  training loss:		0.398960
  validation loss:		0.654841
  training accuracy:		87.58 %
  validation accuracy:		84.69 %
Epoch 113 of 120 time elapsed 1219.309s
  training loss:		0.364441
  validation loss:		0.708696
  training accuracy:		88.24 %
  validation accuracy:		83.59 %
Epoch 114 of 120 time elapsed 1230.002s
  training loss:		0.377874
  validation loss:		0.705161
  training accuracy:		88.20 %
  validation accuracy:		84.53 %
Epoch 115 of 120 time elapsed 1240.631s
  training loss:		0.392307
  validation loss:		0.721772
  training accuracy:		87.01 %
  validation accuracy:		83.75 %
Epoch 116 of 120 time elapsed 1251.244s
  training loss:		0.379157
  validation loss:		0.743919
  training accuracy:		87.17 %
  validation accuracy:		84.53 %
Epoch 117 of 120 time elapsed 1261.901s
  training loss:		0.361248
  validation loss:		0.704208
  training accuracy:		88.40 %
  validation accuracy:		85.31 %
Epoch 118 of 120 time elapsed 1272.593s
  training loss:		0.381100
  validation loss:		0.776958
  training accuracy:		87.62 %
  validation accuracy:		82.66 %
Epoch 119 of 120 time elapsed 1283.278s
  training loss:		0.407693
  validation loss:		0.754406
  training accuracy:		86.47 %
  validation accuracy:		85.00 %
Epoch 120 of 120 time elapsed 1293.943s
  training loss:		0.407502
  validation loss:		0.754387
  training accuracy:		86.92 %
  validation accuracy:		84.22 %

In [11]:
print("Minimum validation loss: {:.6f}".format(min_val_loss))


Minimum validation loss: 0.623988
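The parameters saved at that epoch can be loaded back into the compiled network with the standard Lasagne pattern for npz checkpoints:

# Restore the parameters saved at the epoch with the lowest validation loss
with np.load('params/CNN-LSTM-Attention_params.npz') as f:
    best_params = [f['arr_%d' % i] for i in range(len(f.files))]
lasagne.layers.set_all_param_values(l_out, best_params)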

Model loss and accuracy

Here we plot the loss and the accuracy on the training and validation sets at each epoch.


In [12]:
x_axis = range(num_epochs)
plt.figure(figsize=(8,6))
plt.plot(x_axis,loss_training)
plt.plot(x_axis,loss_validation)
plt.xlabel('Epoch')
plt.ylabel('Error')
plt.legend(('Training','Validation'));



In [13]:
plt.figure(figsize=(8,6))
plt.plot(x_axis,acc_training)
plt.plot(x_axis,acc_validation)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(('Training','Validation'));


Confusion matrix

The confusion matrix lets us visualize how well each class is predicted and which misclassifications are most common.


In [14]:
# Plot confusion matrix 
# Code based on http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

plt.figure(figsize=(8,8))
cmap=plt.cm.Blues   
plt.imshow(cf_val, interpolation='nearest', cmap=cmap)
plt.title('Confusion matrix validation set')
plt.colorbar()
tick_marks = np.arange(n_class)
classes = ['Nucleus','Cytoplasm','Extracellular','Mitochondrion','Cell membrane','ER',
           'Chloroplast','Golgi apparatus','Lysosome','Vacuole']

plt.xticks(tick_marks, classes, rotation=60)
plt.yticks(tick_marks, classes)

thresh = cf_val.max() / 2.
for i, j in itertools.product(range(cf_val.shape[0]), range(cf_val.shape[1])):
    plt.text(j, i, cf_val[i, j],
             horizontalalignment="center",
             color="white" if cf_val[i, j] > thresh else "black")

plt.tight_layout()
plt.ylabel('True location')
plt.xlabel('Predicted location');
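A compact summary of the same matrix is the per-class recall, i.e. the diagonal divided by the row sums. This sketch assumes ret_mat returns rows as true classes and columns as predictions, which is what the axis labels above imply:

# Per-class recall from the validation confusion matrix
recall = cf_val.diagonal().astype(np.float32) / cf_val.sum(axis=1)
for name, r in zip(classes, recall):
    print('{:16s} {:.2f}'.format(name, r))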


Sequence attention

Here we plot the attention values for the last validation batch. We can see that the attention differs between subcellular localizations. The clearest example is the extracellular proteins, whose attention concentrates at the N-terminus, where the signal peptide is located.


In [15]:
sort_ind = np.argsort(targets)
alphas_1 = alphas[:,1,:][sort_ind]
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(15,15));
labels_plot = ax1.imshow(targets[sort_ind].reshape(128,1),cmap=plt.get_cmap('Set1'))
ax1.set_aspect(0.3)
ax1.set_axis_off()
cb = plt.colorbar(labels_plot)
labels = np.arange(0,10,1)
loc = labels + .5
cb.set_ticks(loc)
cb.set_ticklabels(classes)
att_plot = ax2.imshow(alphas_1, aspect='auto')
ax2.yaxis.set_visible(False)
plt.tight_layout(pad=25, w_pad=0.5, h_pad=1.0)
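To compare the localizations more directly, one can also average the attention weights per class present in this batch (a sketch reusing the variables from the cell above):

# Mean attention profile for each class present in this batch
sorted_targets = targets[sort_ind]
mean_att = np.array([alphas_1[sorted_targets == c].mean(axis=0)
                     for c in np.unique(sorted_targets)])
plt.figure(figsize=(10, 4))
plt.imshow(mean_att, aspect='auto')
plt.xlabel('Sequence position')
plt.ylabel('Class');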


